from transformers import BertTokenizer
import torch
from torch import nn
from datasets import load_dataset, DatasetDict, concatenate_datasets
import os
from datasets import load_from_disk
import argparse
import time

# Wall-clock start, used for the total-runtime report at the end of the script.
start_time = time.time()
##################################################################################################################################################
# hyperparameters
##################################################################################################################################################
parser = argparse.ArgumentParser()
parser.add_argument('--epochs', type=int, default=3)
parser.add_argument('--batch-size', type=int, default=32)
parser.add_argument('--lr', type=float, default=5e-5)
parser.add_argument(
    '--optimizer',
    choices=['sgd', 'adam', 'lamb', 'adamW', 'ALTO', 'adaBelief'],
    default='ALTO',
)
# First beta coefficient; in this script only the ALTO optimizer branch reads args.beta.
parser.add_argument('--beta', type=float, default=0.9)
args = parser.parse_args()
##################################################################################################################################################
# dataset
##################################################################################################################################################
# BERT tokenizer (cased vocabulary).
# NOTE(review): "....../" looks like a redacted placeholder — point it at a real bert-base-cased checkpoint.
tokenizer = BertTokenizer.from_pretrained("....../bert-base-cased")

def _load_labelled_reviews(pattern, label):
    """Load every text file matching *pattern* as one example tagged with *label*.

    Args:
        pattern: Glob pattern of review files (e.g. ``.../train/pos/*.txt``).
        label: Integer class written into the ``"labels"`` column of every row.

    Returns:
        A ``Dataset`` with a ``"text"`` column and a constant ``"labels"`` column.
    """
    # sample_by="document" makes each file exactly one example. The builder's
    # default ("line") would split a multi-line review (or emit blank lines) as
    # separate examples, each carrying the whole file's label.
    reviews = load_dataset('text', data_files=pattern, sample_by="document", split='train')
    return reviews.map(lambda examples: {"labels": [label] * len(examples["text"])}, batched=True)


def load_imdb_dataset(base_path):
    """Load the raw aclImdb reviews as a labelled ``DatasetDict``.

    Expects the aclImdb layout ``<base_path>/{train,test}/{pos,neg}/*.txt``.
    Reviews under ``pos`` get ``labels == 1`` and reviews under ``neg`` get
    ``labels == 0``. Within each split the positive examples come first,
    followed by the negative ones.

    Args:
        base_path: Root directory of the extracted aclImdb dataset.

    Returns:
        A ``DatasetDict`` with ``"train"`` and ``"test"`` splits, each having a
        ``"text"`` column and an integer ``"labels"`` column.
    """
    splits = {}
    for split in ("train", "test"):
        # Positives first, then negatives.
        splits[split] = concatenate_datasets([
            _load_labelled_reviews(os.path.join(base_path, split, "pos", "*.txt"), 1),
            _load_labelled_reviews(os.path.join(base_path, split, "neg", "*.txt"), 0),
        ])
    return DatasetDict(splits)

# Root of the extracted aclImdb dataset; load_imdb_dataset reads its train/ and test/ subfolders.
# NOTE(review): "....../" looks like a redacted placeholder — set a real path before running.
base_path = "....../dataset/aclImdb"

def tokenize_function(examples):
    """Tokenize a batch of ``examples["text"]`` with the module-level BERT tokenizer.

    Every sequence is truncated or padded to exactly 512 tokens, so all rows in
    the result share one length. Returns the tokenizer's batch encoding, whose
    ``input_ids`` and ``attention_mask`` fields are later read by ``collate_fn``.
    """
    texts = examples["text"]
    encoded = tokenizer(
        texts,
        truncation=True,
        padding="max_length",
        max_length=512,
    )
    return encoded

def save_dataset(dataset, save_path):
    """Persist *dataset* to the directory *save_path* via its ``save_to_disk`` method.

    Returns ``None``.
    """
    writer = dataset.save_to_disk
    writer(save_path)

def load_dataset_if_exists(load_path):
    """Return the dataset saved at *load_path*, or ``None`` when it is not there.

    Only ``FileNotFoundError`` from ``load_from_disk`` is treated as "no cache";
    any other exception it raises propagates to the caller.
    """
    try:
        cached = load_from_disk(load_path)
    except FileNotFoundError:
        cached = None
    return cached

# Directory holding the tokenized DatasetDict cache.
# NOTE(review): "/cache" sits at the filesystem root and may not be writable. The
# cache location is also not keyed on the tokenizer or max_length, so a stale cache
# will be reused after changing tokenize_function — delete it by hand in that case.
dataset_path = "/cache"

# Reuse a previously tokenized dataset if one was saved at dataset_path.
tokenized_datasets = load_dataset_if_exists(dataset_path)

# Cache miss: build from the raw review files, tokenize in batches, and save for the next run.
if tokenized_datasets is None:
    dataset = load_imdb_dataset(base_path)
    tokenized_datasets = dataset.map(tokenize_function, batched=True)
    save_dataset(tokenized_datasets, dataset_path)

##################################################################################################################################################

from transformers import BertForSequenceClassification
from torch.utils.data.dataloader import default_collate
import torch
from torch.nn.utils.rnn import pad_sequence

def collate_fn(batch):
    """Stack a list of tokenized examples into tensors for the model.

    Each element of *batch* must provide ``input_ids``, ``attention_mask`` and
    ``labels``. Token sequences are padded with 0 up to the longest sequence in
    the batch and stacked batch-first; labels become a 1-D tensor.

    Returns:
        dict with keys ``'input_ids'``, ``'attention_mask'`` and ``'labels'``.
    """
    def _stack(key):
        # Pad variable-length sequences for one field into a (batch, max_len) tensor.
        return pad_sequence([torch.tensor(example[key]) for example in batch], batch_first=True)

    collated = {
        'input_ids': _stack('input_ids'),
        'attention_mask': _stack('attention_mask'),
    }
    collated['labels'] = torch.tensor([example['labels'] for example in batch])
    return collated

# Target device: CUDA when available, otherwise CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# bert-base-cased with a 2-label classification head (0 = negative, 1 = positive reviews).
# NOTE(review): "....../" looks like a redacted placeholder path.
model = BertForSequenceClassification.from_pretrained("....../bert-base-cased", num_labels=2)

# With more than one GPU, wrap the model in DataParallel for multi-GPU execution.
if torch.cuda.device_count() > 1:
    print(f"{torch.cuda.device_count()} GPUs available. Using DataParallel.")
    model = torch.nn.DataParallel(model)
model.to(device)

# Optional: uncomment to freeze the BERT encoder and train only the classifier head.
# NOTE(review): once wrapped in DataParallel the encoder is reached via model.module.bert.
# for param in model.bert.parameters():
#     param.requires_grad = False

from torch.utils.data import DataLoader
#from transformers import AdamW
from sklearn.metrics import precision_score, recall_score, f1_score
from torch.optim import Adam, SGD, AdamW
import sys
import os
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from optimizers.lamb import create_lamb_optimizer
from optimizers.ALTO import create_ALTO_optimizer
from adabelief_pytorch import AdaBelief
# Data loaders over the tokenized splits, batched with collate_fn.
# torch.cuda.device_count() is 0 on CPU-only machines, so clamp the divisor to 1 to
# avoid ZeroDivisionError. Also clamp the result to at least 1, because DataLoader
# rejects batch_size=0 (which happens when --batch-size is smaller than the GPU count).
# NOTE(review): under DataParallel the loader batch is the global batch that gets split
# across GPUs, so dividing by the GPU count makes the effective global batch
# args.batch_size // n_gpus. Kept to preserve existing multi-GPU runs — confirm intent.
loader_batch_size = max(1, args.batch_size // max(1, torch.cuda.device_count()))
train_dataloader = DataLoader(tokenized_datasets["train"], batch_size=loader_batch_size, shuffle=True, collate_fn=collate_fn)
test_dataloader = DataLoader(tokenized_datasets["test"], batch_size=loader_batch_size, collate_fn=collate_fn)


learning_rate = args.lr

# Each --optimizer choice maps to a zero-argument factory, so only the selected
# optimizer is actually constructed.
_OPTIMIZER_FACTORIES = {
    'sgd': lambda: SGD(model.parameters(), lr=learning_rate),
    'adam': lambda: Adam(model.parameters(), lr=learning_rate),
    'adamW': lambda: AdamW(model.parameters(), lr=learning_rate),
    'adaBelief': lambda: AdaBelief(model.parameters(), lr=learning_rate, betas=(0.9, 0.999)),
    # ALTO takes --beta as its first beta; the other two are fixed here.
    'ALTO': lambda: create_ALTO_optimizer(model, lr=learning_rate, betas=(args.beta, 0.9, 0.99), weight_decay=1e-4),
    'lamb': lambda: create_lamb_optimizer(model, lr=learning_rate, weight_decay=1e-4),
}

if args.optimizer not in _OPTIMIZER_FACTORIES:
    raise ValueError('Unknown optimizer: {}'.format(args.optimizer))
optimizer = _OPTIMIZER_FACTORIES[args.optimizer]()


def compute_accuracy(predictions, labels):
    """Return the fraction of rows whose highest-scoring class equals the label.

    Args:
        predictions: Score tensor; class scores are taken along dim 1
            (expected shape (batch, num_classes) — as passed from the training loop).
        labels: Tensor of class indices, one per row of *predictions*.

    Returns:
        Accuracy as a Python float.
    """
    predicted = torch.max(predictions, dim=1).indices
    hits = (predicted == labels).float()
    return hits.sum().item() / len(hits)

print("start training...")
for epoch in range(args.epochs):
    model.train()
    # Per-epoch running sums: batch losses, correctly classified examples, examples seen.
    total_loss = 0
    total_accuracy = 0
    total_samples = 0

    for i, batch in enumerate(train_dataloader):
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        labels = batch["labels"].to(device)

        outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
        loss = outputs.loss

        # With DataParallel the gathered loss holds one value per GPU; average it
        # to a scalar before backward().
        if torch.cuda.device_count() > 1:
            loss = torch.mean(loss)
        # Count each batch's loss exactly once so avg_loss below is the mean batch loss.
        total_loss += loss.item()

        logits = outputs.logits
        accuracy = compute_accuracy(logits, labels)
        # Weight batch accuracy by batch size so a short final batch is not over-counted.
        total_accuracy += accuracy * len(labels)
        total_samples += len(labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if (i + 1) % 100 == 0:
            print(f"Epoch {epoch + 1}/{args.epochs}, Step {i + 1}/{len(train_dataloader)}, Loss: {loss.item()}")

    avg_loss = total_loss / len(train_dataloader)
    avg_accuracy = total_accuracy / total_samples
    print(f"Epoch {epoch + 1}/{args.epochs} - Loss: {avg_loss:.4f}, Accuracy: {avg_accuracy:.4f}")


print("start testing...")
model.eval()
total_correct = 0
total_count = 0
all_predictions = []
all_labels = []

for batch in test_dataloader:
    # Move every tensor produced by collate_fn onto the model's device.
    batch = {name: tensor.to(device) for name, tensor in batch.items()}
    labels = batch['labels']

    with torch.no_grad():
        outputs = model(**batch)
        logits = outputs.logits

    # Predicted class = index of the largest logit; computed once and reused below.
    predicted = torch.max(logits, dim=1)[1]
    total_correct += (predicted == labels).sum().item()
    total_count += labels.size(0)
    all_predictions.extend(predicted.cpu().numpy())
    all_labels.extend(labels.cpu().numpy())

overall_accuracy = total_correct / total_count
# zero_division=0: report 0 instead of warning when a metric's denominator is empty.
precision = precision_score(all_labels, all_predictions, zero_division=0)
recall = recall_score(all_labels, all_predictions, zero_division=0)
f1 = f1_score(all_labels, all_predictions, zero_division=0)

print("Test Metrics:")
print(f" Accuracy: {overall_accuracy:.4f}, Precision: {precision:.4f}, Recall: {recall:.4f}, F1 Score: {f1:.4f}")

# Total wall-clock time since start_time was recorded at the top of the script.
end_time = time.time()
total_time = end_time - start_time
print(f"Total running time: {total_time:.2f} seconds")